{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Tracr_to_Transformer_Lens_Demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tracr to TransformerLens Converter\n",
    "[Tracr](https://github.com/deepmind/tracr) is a cool DeepMind tool that compiles a program written in RASP into transformer weights. TransformerLens is a library I've written for doing mechanistic interpretability on transformers and poking around at their internals. This is a (hacky!) script that converts Tracr weights from their JAX form into a TransformerLens `HookedTransformer` in PyTorch.\n",
    "\n",
    "See [the TransformerLens tutorial](https://neelnanda.io/transformer-lens-demo) to get started."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python 3.8 or later is required. (My fork of Tracr is backward compatible with 3.8; the original library requires at least 3.9.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python 3.8.15\n"
     ]
    }
   ],
   "source": [
    "!python --version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "    %pip install transformer_lens\n",
    "    # Fork of Tracr that's backward compatible with Python 3.8\n",
    "    %pip install git+https://github.com/neelnanda-io/Tracr\n",
    "\n",
    "except ImportError:\n",
    "    IN_COLAB = False\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    # from IPython import get_ipython\n",
    "\n",
    "    # ipython = get_ipython()\n",
    "    # # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel\n",
    "    # ipython.magic(\"load_ext autoreload\")\n",
    "    # ipython.magic(\"autoreload 2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/python38/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformer_lens import HookedTransformer, HookedTransformerConfig\n",
    "import einops\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from tracr.rasp import rasp\n",
    "from tracr.compiler import compiling"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load an example RASP program and compile it to a model. This program reverses lists. The model takes as input a list of pre-tokenization elements (here `[\"BOS\", 1, 2, 3]`), which are tokenized (`[3, 0, 1, 2]`). The transformer is then applied, an argmax is taken over the output, and the result is detokenized. The decoded result is available on the `out.decoded` attribute of the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Creating a SequenceMap with both inputs being the same SOp is discouraged. You should use a Map instead.\n",
      "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def make_length():\n",
    "  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)\n",
    "  return rasp.SelectorWidth(all_true_selector)\n",
    "\n",
    "\n",
    "length = make_length()  # `length` is not a primitive in our implementation.\n",
    "opp_index = length - rasp.indices - 1\n",
    "flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)\n",
    "reverse = rasp.Aggregate(flip, rasp.tokens)\n",
    "\n",
    "bos = \"BOS\"\n",
    "model = compiling.compile_rasp_to_model(\n",
    "    reverse,\n",
    "    vocab={1, 2, 3},\n",
    "    max_seq_len=5,\n",
    "    compiler_bos=bos,\n",
    ")\n",
    "\n",
    "out = model.apply([bos, 1, 2, 3])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract the model config from the Tracr model, and create a blank `HookedTransformer` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# %%\n",
    "\n",
    "n_heads = model.model_config.num_heads\n",
    "n_layers = model.model_config.num_layers\n",
    "d_head = model.model_config.key_size\n",
    "d_mlp = model.model_config.mlp_hidden_size\n",
    "act_fn = \"relu\"\n",
    "normalization_type = \"LN\" if model.model_config.layer_norm else None\n",
    "attention_type = \"causal\" if model.model_config.causal else \"bidirectional\"\n",
    "\n",
    "\n",
    "n_ctx = model.params[\"pos_embed\"]['embeddings'].shape[0]\n",
    "# Size of the input vocab, including BOS and PAD at the end\n",
    "d_vocab = model.params[\"token_embed\"]['embeddings'].shape[0]\n",
    "# Residual stream width. I don't know of an easy way to infer it from the above config, so read it off the embedding.\n",
    "d_model = model.params[\"token_embed\"]['embeddings'].shape[1]\n",
    "\n",
    "# Size of the output vocab, WITHOUT BOS and PAD at the end, because we never care about these outputs.\n",
    "# In practice, we always feed the logits into an argmax.\n",
    "d_vocab_out = model.params[\"token_embed\"]['embeddings'].shape[0] - 2\n",
    "\n",
    "cfg = HookedTransformerConfig(\n",
    "    n_layers=n_layers,\n",
    "    d_model=d_model,\n",
    "    d_head=d_head,\n",
    "    n_ctx=n_ctx,\n",
    "    d_vocab=d_vocab,\n",
    "    d_vocab_out=d_vocab_out,\n",
    "    d_mlp=d_mlp,\n",
    "    n_heads=n_heads,\n",
    "    act_fn=act_fn,\n",
    "    attention_dir=attention_type,\n",
    "    normalization_type=normalization_type,\n",
    ")\n",
    "tl_model = HookedTransformer(cfg)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract the state dict, and do some reshaping so that the attention weights have a separate `n_heads` dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.mlp.W_in', 'blocks.3.mlp.b_in', 'blocks.3.mlp.W_out', 'blocks.3.mlp.b_out'])\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# %%\n",
    "sd = {}\n",
    "# The first dimension is n_ctx, which is max_seq_len plus one, for the BOS\n",
    "sd[\"pos_embed.W_pos\"] = model.params[\"pos_embed\"]['embeddings']\n",
    "sd[\"embed.W_E\"] = model.params[\"token_embed\"]['embeddings']\n",
    "\n",
    "# The unembed is just a projection onto the first few elements of the residual stream, which store the output tokens.\n",
    "# This is a NumPy array while the rest are JAX arrays, but that's fine: everything is converted to torch tensors before loading.\n",
    "sd[\"unembed.W_U\"] = np.eye(d_model, d_vocab_out)\n",
    "\n",
    "for l in range(n_layers):\n",
    "    sd[f\"blocks.{l}.attn.W_K\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/key\"][\"w\"],\n",
    "        \"d_model (n_heads d_head) -> n_heads d_model d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.b_K\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/key\"][\"b\"],\n",
    "        \"(n_heads d_head) -> n_heads d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.W_Q\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/query\"][\"w\"],\n",
    "        \"d_model (n_heads d_head) -> n_heads d_model d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.b_Q\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/query\"][\"b\"],\n",
    "        \"(n_heads d_head) -> n_heads d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.W_V\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/value\"][\"w\"],\n",
    "        \"d_model (n_heads d_head) -> n_heads d_model d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.b_V\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/value\"][\"b\"],\n",
    "        \"(n_heads d_head) -> n_heads d_head\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.W_O\"] = einops.rearrange(\n",
    "        model.params[f\"transformer/layer_{l}/attn/linear\"][\"w\"],\n",
    "        \"(n_heads d_head) d_model -> n_heads d_head d_model\",\n",
    "        d_head = d_head,\n",
    "        n_heads = n_heads\n",
    "    )\n",
    "    sd[f\"blocks.{l}.attn.b_O\"] = model.params[f\"transformer/layer_{l}/attn/linear\"][\"b\"]\n",
    "\n",
    "    sd[f\"blocks.{l}.mlp.W_in\"] = model.params[f\"transformer/layer_{l}/mlp/linear_1\"][\"w\"]\n",
    "    sd[f\"blocks.{l}.mlp.b_in\"] = model.params[f\"transformer/layer_{l}/mlp/linear_1\"][\"b\"]\n",
    "    sd[f\"blocks.{l}.mlp.W_out\"] = model.params[f\"transformer/layer_{l}/mlp/linear_2\"][\"w\"]\n",
    "    sd[f\"blocks.{l}.mlp.b_out\"] = model.params[f\"transformer/layer_{l}/mlp/linear_2\"][\"b\"]\n",
    "print(sd.keys())\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the weights to torch tensors and load them into `tl_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=['blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.2.attn.mask', 'blocks.2.attn.IGNORE', 'blocks.3.attn.mask', 'blocks.3.attn.IGNORE', 'unembed.b_U'], unexpected_keys=[])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
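   ],
   "source": [
    "\n",
    "for k, v in sd.items():\n",
    "    # I couldn't find a neater way to go from a JAX array to a NumPy array\n",
    "    sd[k] = torch.tensor(np.array(v))\n",
    "\n",
    "# strict=False because sd does not contain the attention mask buffers or unembed.b_U (see the missing keys in the output)\n",
    "tl_model.load_state_dict(sd, strict=False)\n"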
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create helper functions for tokenization and detokenization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# %%\n",
    "INPUT_ENCODER = model.input_encoder\n",
    "OUTPUT_ENCODER = model.output_encoder\n",
    "\n",
    "def create_model_input(input, input_encoder=INPUT_ENCODER):\n",
    "    # Tokenize the raw input and add a batch dimension of size 1\n",
    "    encoding = input_encoder.encode(input)\n",
    "    return torch.tensor(encoding).unsqueeze(dim=0)\n",
    "\n",
    "def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):\n",
    "    # Drop the batch dimension and take the argmax over the vocab at each position\n",
    "    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)\n",
    "    decoded_output = output_encoder.decode(max_output_indices.tolist())\n",
    "    # Overwrite the first position with the BOS token, matching the format of `out.decoded`\n",
    "    decoded_output_with_bos = [bos_token] + decoded_output[1:]\n",
    "    return decoded_output_with_bos\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now run the model!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Decoding: ['BOS', 3, 2, 1]\n",
      "TransformerLens Replicated Decoding: ['BOS', 3, 2, 1]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "input = [bos, 1, 2, 3]\n",
    "out = model.apply(input)\n",
    "print(\"Original Decoding:\", out.decoded)\n",
    "\n",
    "input_tokens_tensor = create_model_input(input)\n",
    "logits = tl_model(input_tokens_tensor)\n",
    "decoded_output = decode_model_output(logits)\n",
    "print(\"TransformerLens Replicated Decoding:\", decoded_output)\n",
    "# %%\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's cache all intermediate activations in the model, and check that they match the Tracr model's:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer 0 Attn Out Equality Check: True\n",
      "Layer 0 MLP Out Equality Check: True\n",
      "Layer 1 Attn Out Equality Check: True\n",
      "Layer 1 MLP Out Equality Check: True\n",
      "Layer 2 Attn Out Equality Check: True\n",
      "Layer 2 MLP Out Equality Check: True\n",
      "Layer 3 Attn Out Equality Check: True\n",
      "Layer 3 MLP Out Equality Check: True\n"
     ]
    }
   ],
   "source": [
    "logits, cache = tl_model.run_with_cache(input_tokens_tensor)\n",
    "\n",
    "# Tracr's layer_outputs alternate attention and MLP outputs, so layer i maps to indices 2*i and 2*i+1\n",
    "for layer in range(tl_model.cfg.n_layers):\n",
    "    tl_attn_out = cache[\"attn_out\", layer].detach().cpu().numpy()\n",
    "    tl_mlp_out = cache[\"mlp_out\", layer].detach().cpu().numpy()\n",
    "    print(f\"Layer {layer} Attn Out Equality Check:\", np.isclose(tl_attn_out, np.array(out.layer_outputs[2*layer])).all())\n",
    "    print(f\"Layer {layer} MLP Out Equality Check:\", np.isclose(tl_mlp_out, np.array(out.layer_outputs[2*layer+1])).all())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look how pretty and ordered the final residual stream is!\n",
    "\n",
    "(The logits are the first 3 dimensions of the residual stream, and we can see that they're flipped!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Plotly heatmap of the final residual stream. x-axis: Residual Stream; y-axis: Position (BOS, 1, 2, 3).]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "px.imshow(\n",
    "    cache[\"resid_post\", -1].detach().cpu().numpy()[0],\n",
    "    color_continuous_scale=\"Blues\", labels={\"x\": \"Residual Stream\", \"y\": \"Position\"},\n",
    "    y=[str(i) for i in input],\n",
    ").show(\"colab\" if IN_COLAB else \"\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15 (default, Nov 24 2022, 15:19:38) \n[GCC 11.2.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d6970b9bc1f22c1555ce2e3aef3e9bc8c56c5727cd75cae357902c75ead4068e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
